LSTM

LSTM(长短期记忆网络)是一种常用的循环神经网络(RNN)模块,用于处理具有时序依赖特征的数据(如语音、文本、时间序列等)。每个时间步的计算公式如下。

\[\begin{split}\begin{aligned} i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) && \text{(输入门)} \\[6pt] f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) && \text{(遗忘门)} \\[6pt] g_t &= \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) && \text{(候选状态)} \\[6pt] o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) && \text{(输出门)} \\[6pt] c_t &= f_t \odot c_{t-1} + i_t \odot g_t && \text{(细胞状态更新)} \\[6pt] h_t &= o_t \odot \tanh(c_t) && \text{(隐藏状态更新)} \end{aligned}\end{split}\]
  • \(x_t\) : 当前时间步输入向量

  • \(h_{t-1}\) : 上一时间步的隐藏状态

  • \(c_{t-1}\) : 上一时间步的细胞状态

  • \(i_t, f_t, g_t, o_t\) : 当前时间步的输入门、遗忘门、候选状态、输出门

  • \(W_*\) : 对应的权重矩阵

  • \(b_*\) : 偏置项

  • \(\sigma(\cdot)\) : Sigmoid 函数

  • \(\odot\) : 元素乘

输入:
  • input - 输入序列数据,形状为 \((seq\_len, batch, input\_size)\),即每个时间步的输入特征。

  • params - 静态参数数组,包含 LSTM 网络配置、权重、状态指针等;各元素以 unsigned long long 形式依次存放下列各项的地址(即调用示例中的 params[0]~params[7])。
    • weight_i - 输入到各门 \((input, forget, cell, output)\) 的权重矩阵,大小为 \(4 * hidden\_size * input\_size\)。

    • weight_h - 上一隐藏状态到各门的权重矩阵,大小为 \(4 * hidden\_size * hidden\_size\)。

    • input_bias - 输入部分的偏置项,对应 4 个门的偏置,大小为 \(4 * hidden\_size\)。

    • state_bias - 隐藏状态部分的偏置项(也是 \(4 * hidden\_size\)),与 input_bias 一起求和形成总偏置。

    • hidden_state - 当前批次初始隐藏状态输入(\(h_0\)),执行后更新为最后时刻的隐藏状态输出(\(h_t\))。

    • cell_state - 当前批次初始细胞状态输入( \(c_0\)),执行后更新为最后时刻的细胞状态输出( \(c_t\))。

    • buffer - 临时工作区指针数组(中间计算缓存,如门值、激活结果、临时矩阵等,用于优化性能)。

    • LstmParameter - LSTM 配置参数结构体,包含输入大小、隐藏层维度、序列长度、是否双向等信息。

  • core_mask - 核掩码(仅适用于共享存储版本)。

LstmParameter定义:

 /**
  * Configuration block describing one LSTM invocation. Its address is handed
  * to the kernel through the params array (see the call examples).
  */
 typedef struct LstmParameter {
     int input_size_;        /* Dimension of the input vector at each time step (feature count). */
     int hidden_size_;       /* Dimension of the LSTM hidden state (per-gate compute size). */
     int project_size_;      /* Projection-layer output dimension (LSTMP: linearly compresses the hidden state before output). */
     int output_size_;       /* Actual output dimension: hidden_size_ or project_size_, depending on whether projection is used. */
     int seq_len_;           /* Number of time steps in the input sequence. */
     int batch_;             /* Batch size (samples processed per call). */
     /* other parameters */
     int output_step_;       /* Output step. The call examples set it to batch_ * output_size_ (doubled when bidirectional_).
                                NOTE(review): this looks like a per-time-step output stride rather than a step index - confirm. */
     bool bidirectional_;    /* true for a bidirectional LSTM (one forward and one backward layer). */
     float zoneout_cell_;    /* Zoneout ratio applied to the cell state (regularisation). */
     float zoneout_hidden_;  /* Zoneout ratio applied to the hidden state (regularisation). */
     int input_row_align_;   /* Row alignment of the input tensor. */
     int input_col_align_;   /* Column alignment of the input tensor. */
     int state_row_align_;   /* Row alignment of the state tensors (hidden/cell). */
     int state_col_align_;   /* Column alignment of the state tensors. */
     int proj_col_align_;    /* Column alignment of the projection matrix. */
     bool has_bias_;         /* true when bias terms are used. */
 } LstmParameter;

输出:

  • output - 计算结果地址,存放 LSTM 每个时间步输出结果的缓冲区,维度通常为 \((seq\_len, batch, output\_size)\)。

支持平台:

FT78NE MT7004

备注

  • FT78NE 支持fp32

  • MT7004 支持fp32、fp16

共享存储版本:

/*
 * Shared-storage LSTM entry points: fp_lstm_s operates on float data,
 * hp_lstm_s on half data.
 *
 *   output    - result buffer, typically laid out as (seq_len, batch, output_size).
 *   input     - input sequence, shape (seq_len, batch, input_size).
 *   params    - array of addresses stored as unsigned long long; the call example
 *               fills it in the order weight_i, weight_h, input_bias, state_bias,
 *               hidden_state, cell_state, buffer, LstmParameter. hidden_state and
 *               cell_state are updated with the final-step states.
 *   core_mask - core mask (only the shared-storage variants take it).
 *
 * NOTE(review): fp16 is listed only for MT7004, so hp_lstm_s is presumably not
 * available on FT78NE - confirm.
 */
void fp_lstm_s(float *output, const float *input, unsigned long long *params, int core_mask);
void hp_lstm_s(half *output, const half *input, unsigned long long *params, int core_mask);

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <lstm.h>
 4
 5int main(int argc, char* argv[]) {
 6    LstmParameter *lstm_param = (LstmParameter *)0x90000000;
 7    lstm_param->seq_len_ = 4;
 8    lstm_param->batch_ = 1;
 9    lstm_param->input_size_ = 2;
10    lstm_param->hidden_size_ = 3;
11    lstm_param->bidirectional_ = false;
12    float * input = (float *)0xA0000000;
13    float * weight_i = (float *)0xA0001000;
14    float * weight_h = (float *)0xA0003000;
15    float * bias = (float *)0xA0005000;
16    float *hidden_state = (float *)0xA0006000;
17    float *cell_state = (float *)0xA0007000;
18    float *buffer[9];
19    float * packed_input_ = (float *)0xB0000000;
20    buffer[0] = packed_input_;
21    float * gate = (float *)0xB0100000;
22    buffer[1] = gate;
23    float * packed_state = (float *)0xB0200000;
24    buffer[2] = packed_state;
25    float * state_gate = (float *)0xB0300000;
26    buffer[3] = state_gate;
27    float * cell_buffer = (float *)0xB0400000;
28    buffer[4] = cell_buffer;
29    float * hidden_buffer = (float *)0xB0500000;
30    buffer[5] = hidden_buffer;
31    float * packed_output = (float *)0xB0600000;
32    buffer[6] = packed_output;
33    float * left_matrix = (float *)0xB0700000;
34    buffer[7] = left_matrix;
35    float * packed_ptr = (float *)0xB0800000;
36    buffer[8] = packed_ptr;
37    lstm_param->hidden_size_ = 3;
38    lstm_param->output_size_ = 3;
39
40    lstm_param->output_step_ = lstm_param->bidirectional_ ?  2 * lstm_param->batch_ * lstm_param->output_size_
41        : lstm_param->batch_ * lstm_param->output_size_;
42    int weight_segment_num_ = lstm_param->bidirectional_ ? 2 * 4 : 4;
43    int row_tile_ = 12;
44    int col_tile_ = 8;
45    lstm_param->input_row_align_ = UP_ROUND(lstm_param->seq_len_ * lstm_param->batch_, row_tile_);
46    lstm_param->input_col_align_ = UP_ROUND(lstm_param->hidden_size_, col_tile_);
47    int state_row_tile_ = row_tile_;
48    int state_col_tile_ = col_tile_;
49    lstm_param->state_row_align_ = lstm_param->batch_ == 1 ? 1 : UP_ROUND(lstm_param->batch_, state_row_tile_);
50    lstm_param->state_col_align_ =
51            lstm_param->batch_ == 1 ? lstm_param->hidden_size_ : UP_ROUND(lstm_param->hidden_size_, state_col_tile_);
52    lstm_param->proj_col_align_ =
53            lstm_param->batch_ == 1 ? lstm_param->output_size_ : UP_ROUND(lstm_param->output_size_, state_col_tile_);
54    unsigned long long params[9];
55    params[0] = (unsigned long long)weight_i;
56    params[1] = (unsigned long long)weight_h;
57    params[2] = (unsigned long long)input_bias_;//ok
58    params[3] = (unsigned long long)state_bias;
59    params[4] = (unsigned long long)hidden_state;
60    params[5] = (unsigned long long)cell_state;
61    params[6] = (unsigned long long)buffer;
62    params[7] = (unsigned long long)lstm_param;
63    int core_mask = 0x11;
64
65    fp_lstm_s(output, input, params, core_mask);
66        return 0;
67}

私有存储版本:

/*
 * Private-storage LSTM entry points: fp_lstm_p operates on float data,
 * hp_lstm_p on half data. Unlike the shared-storage variants they take no core mask.
 *
 *   output - result buffer, typically laid out as (seq_len, batch, output_size).
 *   input  - input sequence, shape (seq_len, batch, input_size).
 *   params - array of addresses stored as unsigned long long, holding in order
 *            weight_i, weight_h, input_bias, state_bias, hidden_state, cell_state,
 *            buffer, LstmParameter. hidden_state and cell_state are updated with
 *            the final-step states.
 *
 * NOTE(review): fp16 is listed only for MT7004, so hp_lstm_p is presumably not
 * available on FT78NE - confirm.
 */
void fp_lstm_p(float *output, const float *input, unsigned long long *params);
void hp_lstm_p(half *output, const half *input, unsigned long long *params);

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <lstm.h>
 4int main(int argc, char* argv[]) {
 5    LstmParameter *lstm_param = (LstmParameter *)0x10810000;
 6    lstm_param->seq_len_ = 4;
 7    lstm_param->batch_ = 1;
 8    lstm_param->input_size_ = 2;
 9    lstm_param->hidden_size_ = 3;
10    lstm_param->bidirectional_ = false;
11    float * input = (float *)0x10811000;
12    float * weight_i = (float *)0x10812000;
13    float * weight_h = (float *)0x10813000;
14    float * bias = (float *)0x10814000;
15    float *hidden_state = (float *)0x10814800;
16    float *cell_state = (float *)0x10815000;
17    float *buffer[9];
18    float * packed_input_ = (float *)0x10815f00;
19    buffer[0] = packed_input_;
20    float * gate = (float *)0x108160000;
21    buffer[1] = gate;
22    float * packed_state = (float *)0x10816100;
23    buffer[2] = packed_state;
24    float * state_gate = (float *)0x10816200;
25    buffer[3] = state_gate;
26    float * cell_buffer = (float *)0x10816300;
27    buffer[4] = cell_buffer;
28    float * hidden_buffer = (float *)0x10816400;
29    buffer[5] = hidden_buffer;
30    float * packed_output = (float *)0x10816500;
31    buffer[6] = packed_output;
32    float * left_matrix = (float *)0x10816600;
33    buffer[7] = left_matrix;
34    float * packed_ptr = (float *)0x10816700;
35    buffer[8] = packed_ptr;
36    lstm_param->hidden_size_ = 3;
37    lstm_param->output_size_ = 3;
38
39    lstm_param->output_step_ = lstm_param->bidirectional_ ?  2 * lstm_param->batch_ * lstm_param->output_size_
40        : lstm_param->batch_ * lstm_param->output_size_;
41    int weight_segment_num_ = lstm_param->bidirectional_ ? 2 * 4 : 4;
42    int row_tile_ = 12;
43    int col_tile_ = 8;
44    lstm_param->input_row_align_ = UP_ROUND(lstm_param->seq_len_ * lstm_param->batch_, row_tile_);
45    lstm_param->input_col_align_ = UP_ROUND(lstm_param->hidden_size_, col_tile_);
46    int state_row_tile_ = row_tile_;
47    int state_col_tile_ = col_tile_;
48    lstm_param->state_row_align_ = lstm_param->batch_ == 1 ? 1 : UP_ROUND(lstm_param->batch_, state_row_tile_);
49    lstm_param->state_col_align_ =
50            lstm_param->batch_ == 1 ? lstm_param->hidden_size_ : UP_ROUND(lstm_param->hidden_size_, state_col_tile_);
51    lstm_param->proj_col_align_ =
52            lstm_param->batch_ == 1 ? lstm_param->output_size_ : UP_ROUND(lstm_param->output_size_, state_col_tile_);
53    unsigned long long params[9];
54    params[0] = (unsigned long long)weight_i;
55    params[1] = (unsigned long long)weight_h;
56    params[2] = (unsigned long long)input_bias_;//ok
57    params[3] = (unsigned long long)state_bias;
58    params[4] = (unsigned long long)hidden_state;
59    params[5] = (unsigned long long)cell_state;
60    params[6] = (unsigned long long)buffer;
61    params[7] = (unsigned long long)lstm_param;
62    int core_mask = 0x11;
63
64    fp_lstm_s(output, input, params, core_mask);
65    return 0;
66}